/*
 * Wazuh Vulnerability scanner - Scan Orchestrator
 * Copyright (C) 2015, Wazuh Inc.
 * May 1, 2023.
 *
 * This program is free software; you can redistribute it
 * and/or modify it under the terms of the GNU General Public
 * License (version 2) as published by the FSF - Free Software
 * Foundation.
 */

#ifndef _PACKAGE_SCANNER_HPP
#define _PACKAGE_SCANNER_HPP

#include "chainOfResponsability.hpp"
#include "databaseFeedManager.hpp"
#include "scanContext.hpp"
#include "scannerHelper.hpp"
#include "versionMatcher/versionMatcher.hpp"
#include <iostream>
#include <variant>

/**
 * @brief CNA used by TPackageScanner when no CNA can be resolved from the package format or vendor.
 *
 * @note Declared `inline constexpr` so the header provides a single definition shared by every translation unit.
 */
inline constexpr auto DEFAULT_CNA {"nvd"};

/**
 * @brief PackageScanner class.
 * This class is responsible for scanning the package and checking if it is vulnerable.
 * It receives the scan context and the database feed manager and returns the scan context with the vulnerability
 * details. The package format is used to determine the version object type or the version matcher strategy. The package
 * format can be deb, rpm, pypi, npm, pacman, snap, pkg, apk, win, macports. The vulnerability scan is performed using
 * the database feed manager.
 *
 * Every matching CVE is recorded in the scan context (`m_elements` and `m_matchConditions`, both keyed by CVE ID)
 * before the context is forwarded to the next handler of the chain.
 *
 * @tparam TDatabaseFeedManager Database feed manager type (queried for CNA names and vulnerability candidates).
 * @tparam TScanContext Scan context type.
 */
template<typename TDatabaseFeedManager = DatabaseFeedManager, typename TScanContext = ScanContext>
class TPackageScanner final : public AbstractHandler<std::shared_ptr<TScanContext>>
{
private:
    std::shared_ptr<TDatabaseFeedManager> m_databaseFeedManager;

    /**
     * @brief Package format to VersionObjectType / VersionMatcherStrategy map.
     *
     * @note The map is used to determine the version object type or the version matcher strategy based on the package
     * format. Formats that are not listed are compared with VersionMatcherStrategy::Unspecified.
     */
    std::unordered_map<std::string_view, std::variant<VersionObjectType, VersionMatcherStrategy>> m_packageMap {
        {"deb", VersionObjectType::DPKG},
        {"rpm", VersionObjectType::RPM},
        {"pypi", VersionObjectType::PEP440},
        {"npm", VersionObjectType::SemVer},
        {"pacman", VersionMatcherStrategy::Pacman},
        {"snap", VersionMatcherStrategy::Snap},
        {"pkg", VersionMatcherStrategy::PKG},
        {"apk", VersionMatcherStrategy::APK},
        {"win", VersionMatcherStrategy::Windows},
        {"macports", VersionMatcherStrategy::MacOS}};

public:
    // LCOV_EXCL_START
    /**
     * @brief PackageScanner constructor.
     *
     * @param databaseFeedManager Database feed manager. Taken by const reference so that temporaries and const
     * pointers are accepted too; the scanner keeps its own copy of the shared pointer.
     */
    explicit TPackageScanner(const std::shared_ptr<TDatabaseFeedManager>& databaseFeedManager)
        : m_databaseFeedManager(databaseFeedManager)
    {
    }
    // LCOV_EXCL_STOP

    /**
     * @brief Handles request and passes control to the next step of the chain.
     *
     * The CNA used for the lookup is resolved with getCnaNameByFormat (package format), then getCnaNameByPrefix and
     * getCnaNameByContains (package vendor), and finally DEFAULT_CNA. Each vulnerability candidate returned for the
     * lower-cased package name is checked against its platform list and its version entries; matches are recorded
     * in the scan context.
     *
     * @param data Scan context.
     * @return std::shared_ptr<TScanContext> nullptr when the context holds no vulnerability after the scan, otherwise
     * the result of the next handler in the chain.
     */
    std::shared_ptr<TScanContext> handleRequest(std::shared_ptr<TScanContext> data) override
    {
        const std::string packageName {Utils::toLowerCase(std::string(data->packageName()))};

        // The installed version and the comparison strategy depend only on the package, not on the candidate, so
        // they are computed once here instead of once per candidate / per version entry.
        const std::string packageVersion {data->packageVersion()};

        std::variant<VersionObjectType, VersionMatcherStrategy> objectType = VersionMatcherStrategy::Unspecified;
        if (const auto it = m_packageMap.find(data->packageFormat()); it != m_packageMap.end())
        {
            objectType = it->second;
        }

        // Callback invoked for each vulnerability candidate. Returns true when the candidate matched and was recorded
        // in the scan context, false otherwise (including when an exception is raised while evaluating it).
        auto vulnerabilityScan =
            [&](const std::string& cnaName, const NSVulnerabilityScanner::ScanVulnerabilityCandidate& callbackData)
        {
            try
            {
                // if the platforms are not empty, we need to check if the platform is in the list.
                if (callbackData.platforms())
                {
                    bool matchPlatform {false};
                    for (const auto& platform : *callbackData.platforms())
                    {
                        const std::string platformValue {platform->str()};
                        // if the platform is a CPE, we need to parse it and check if the product is the same as the os
                        // cpe.
                        if (ScannerHelper::isCPE(platformValue))
                        {
                            const auto cpe {ScannerHelper::parseCPE(platformValue)};
                            if (cpe.part.compare("o") == 0)
                            {
                                if (ScannerHelper::compareCPE(cpe, ScannerHelper::parseCPE(data->osCPEName().data())))
                                {
                                    logDebug2(WM_VULNSCAN_LOGTAG,
                                              "The platform is in the list based on CPE comparison for "
                                              "Package: %s, Version: %s, CVE: %s, Content platform CPE: %s OS CPE: %s",
                                              packageName.c_str(),
                                              packageVersion.c_str(),
                                              callbackData.cveId()->str().c_str(),
                                              platformValue.c_str(),
                                              data->osCPEName().data());
                                    matchPlatform = true;
                                    break;
                                }
                            }
                        }
                        // If the platform is not a CPE, it is a string, at the moment, we only support the os code
                        // name. This is used mainly for debian and ubuntu platforms.
                        else
                        {
                            if (platformValue.compare(data->osCodeName()) == 0)
                            {
                                matchPlatform = true;
                                break;
                            }
                        }
                    }

                    if (!matchPlatform)
                    {
                        return false;
                    }
                }

                // versions() returns a pointer, like platforms() above. Dereferencing a null one would be undefined
                // behavior (not an exception the catch below could handle), so the entries are only walked when the
                // vector is present; otherwise the default status decides.
                if (const auto* versions = callbackData.versions(); versions != nullptr)
                {
                    for (const auto& version : *versions)
                    {
                        const std::string versionString {version->version() ? version->version()->str() : ""};
                        const std::string versionStringLessThan {version->lessThan() ? version->lessThan()->str()
                                                                                     : ""};
                        const std::string versionStringLessThanOrEqual {
                            version->lessThanOrEqual() ? version->lessThanOrEqual()->str() : ""};

                        logDebug2(
                            WM_VULNSCAN_LOGTAG,
                            "Scanning package - '%s' (Installed Version: %s, Security Vulnerability: %s). Identified "
                            "vulnerability: "
                            "Version: %s. Required Version Threshold: %s. Required Version Threshold (or Equal): %s.",
                            packageName.c_str(),
                            packageVersion.c_str(),
                            callbackData.cveId()->str().c_str(),
                            versionString.c_str(),
                            versionStringLessThan.c_str(),
                            versionStringLessThanOrEqual.c_str());

                        // No version range specified, check if the installed version is equal to the required version.
                        if (versionStringLessThan.empty() && versionStringLessThanOrEqual.empty())
                        {
                            if (VersionMatcher::compare(packageVersion, versionString, objectType) ==
                                VersionComparisonResult::A_EQUAL_B)
                            {
                                // Version match found, the package status is defined by the vulnerability status.
                                if (version->status() == NSVulnerabilityScanner::Status::Status_affected)
                                {
                                    logDebug1(WM_VULNSCAN_LOGTAG,
                                              "Match found, the package '%s', is vulnerable to '%s'. Current version: "
                                              "'%s' is "
                                              "equal to '%s'. - Agent '%s' (ID: '%s', Version: '%s').",
                                              packageName.c_str(),
                                              callbackData.cveId()->str().c_str(),
                                              packageVersion.c_str(),
                                              versionString.c_str(),
                                              data->agentName().data(),
                                              data->agentId().data(),
                                              data->agentVersion().data());

                                    data->m_elements[callbackData.cveId()->str()] = nlohmann::json::object();
                                    data->m_matchConditions[callbackData.cveId()->str()] = {
                                        versionString, MatchRuleCondition::Equal};
                                    return true;
                                }

                                // Exact match on a non-affected entry: the package is explicitly not vulnerable.
                                return false;
                            }
                        }
                        else
                        {
                            // Version range specified

                            // Check if the installed version satisfies the lower bound of the version range.
                            // A lower bound of "0" is treated as "no lower bound".
                            auto lowerBoundMatch = false;
                            if (versionString.compare("0") == 0)
                            {
                                lowerBoundMatch = true;
                            }
                            else
                            {
                                const auto matchResult =
                                    VersionMatcher::compare(packageVersion, versionString, objectType);
                                lowerBoundMatch = matchResult == VersionComparisonResult::A_GREATER_THAN_B ||
                                                  matchResult == VersionComparisonResult::A_EQUAL_B;
                            }

                            if (lowerBoundMatch)
                            {
                                // Check if the installed version satisfies the upper bound of the version range.
                                // NOTE(review): a "lessThan" of "*" without a "lessThanOrEqual" ends up as
                                // upperBoundMatch == false, so this entry does not match and the remaining entries
                                // / default status decide. Confirm this is intended for unbounded ranges.
                                auto upperBoundMatch = false;
                                if (!versionStringLessThan.empty() && versionStringLessThan.compare("*") != 0)
                                {
                                    const auto matchResult =
                                        VersionMatcher::compare(packageVersion, versionStringLessThan, objectType);
                                    upperBoundMatch = matchResult == VersionComparisonResult::A_LESS_THAN_B;
                                }
                                else if (!versionStringLessThanOrEqual.empty())
                                {
                                    const auto matchResult = VersionMatcher::compare(
                                        packageVersion, versionStringLessThanOrEqual, objectType);
                                    upperBoundMatch = matchResult == VersionComparisonResult::A_LESS_THAN_B ||
                                                      matchResult == VersionComparisonResult::A_EQUAL_B;
                                }
                                else
                                {
                                    upperBoundMatch = false;
                                }

                                if (upperBoundMatch)
                                {
                                    // Version match found, the package status is defined by the vulnerability status.
                                    if (version->status() == NSVulnerabilityScanner::Status::Status_affected)
                                    {
                                        logDebug1(WM_VULNSCAN_LOGTAG,
                                                  "Match found, the package '%s', is vulnerable to '%s'. Current "
                                                  "version: "
                                                  "'%s' ("
                                                  "less than '%s' or equal to '%s'). - Agent '%s' (ID: '%s', "
                                                  "Version: '%s').",
                                                  packageName.c_str(),
                                                  callbackData.cveId()->str().c_str(),
                                                  packageVersion.c_str(),
                                                  versionStringLessThan.c_str(),
                                                  versionStringLessThanOrEqual.c_str(),
                                                  data->agentName().data(),
                                                  data->agentId().data(),
                                                  data->agentVersion().data());

                                        data->m_elements[callbackData.cveId()->str()] = nlohmann::json::object();

                                        // The inclusive bound takes precedence when recording the match condition.
                                        if (!versionStringLessThanOrEqual.empty())
                                        {
                                            data->m_matchConditions[callbackData.cveId()->str()] = {
                                                versionStringLessThanOrEqual, MatchRuleCondition::LessThanOrEqual};
                                        }
                                        else
                                        {
                                            data->m_matchConditions[callbackData.cveId()->str()] = {
                                                versionStringLessThan, MatchRuleCondition::LessThan};
                                        }
                                        return true;
                                    }
                                    else
                                    {
                                        logDebug2(WM_VULNSCAN_LOGTAG,
                                                  "No match due to default status for Package: %s, Version: %s while "
                                                  "scanning "
                                                  "for Vulnerability: %s, "
                                                  "Installed Version: %s, Required Version Threshold: %s, Required "
                                                  "Version "
                                                  "Threshold (or Equal): %s",
                                                  packageName.c_str(),
                                                  packageVersion.c_str(),
                                                  callbackData.cveId()->str().c_str(),
                                                  versionString.c_str(),
                                                  versionStringLessThan.c_str(),
                                                  versionStringLessThanOrEqual.c_str());

                                        return false;
                                    }
                                }
                            }
                        }
                    }
                }

                // No match found, the default status defines the package status.
                if (callbackData.defaultStatus() == NSVulnerabilityScanner::Status::Status_affected)
                {
                    logDebug1(WM_VULNSCAN_LOGTAG,
                              "Match found, the package '%s' is vulnerable to '%s' due to default status. - Agent "
                              "'%s' (ID: '%s', Version: '%s').",
                              packageName.c_str(),
                              callbackData.cveId()->str().c_str(),
                              data->agentName().data(),
                              data->agentId().data(),
                              data->agentVersion().data());

                    data->m_elements[callbackData.cveId()->str()] = nlohmann::json::object();
                    data->m_matchConditions[callbackData.cveId()->str()] = {"", MatchRuleCondition::DefaultStatus};
                    return true;
                }

                logDebug2(
                    WM_VULNSCAN_LOGTAG,
                    "No match due to default status for Package: %s, Version: %s while scanning for Vulnerability: %s",
                    packageName.c_str(),
                    packageVersion.c_str(),
                    callbackData.cveId()->str().c_str());

                return false;
            }
            catch (const std::exception& e)
            {
                // Log the warning and continue with the next vulnerability.
                logDebug1(WM_VULNSCAN_LOGTAG,
                          "Failed to scan package: '%s', CVE Numbering Authorities (CNA): '%s', Error: '%s'",
                          packageName.c_str(),
                          cnaName.c_str(),
                          e.what());

                return false;
            }
        };

        // Resolve the CNA: by package format first, then by vendor, falling back to DEFAULT_CNA.
        auto cnaName {m_databaseFeedManager->getCnaNameByFormat(data->packageFormat().data())};

        if (cnaName.empty())
        {
            cnaName = m_databaseFeedManager->getCnaNameByPrefix(data->packageVendor().data());
            if (cnaName.empty())
            {
                cnaName = m_databaseFeedManager->getCnaNameByContains(data->packageVendor().data());
                if (cnaName.empty())
                {
                    cnaName = DEFAULT_CNA;
                }
            }
        }

        logDebug1(WM_VULNSCAN_LOGTAG,
                  "Initiating a vulnerability scan for package '%s' (%s) (%s) with CVE Numbering Authorities (CNA) "
                  "'%s' on Agent "
                  "'%s' (ID: '%s', Version: '%s').",
                  packageName.c_str(),
                  data->packageFormat().data(),
                  data->packageVendor().data(),
                  cnaName.c_str(),
                  data->agentName().data(),
                  data->agentId().data(),
                  data->agentVersion().data());

        try
        {
            m_databaseFeedManager->getVulnerabilitiesCandidates(cnaName, packageName, vulnerabilityScan);
        }
        catch (const std::exception& e)
        {
            logWarn(WM_VULNSCAN_LOGTAG,
                    "Failed to scan package: '%s', CVE Numbering Authorities (CNA): '%s', Error: '%s'",
                    packageName.c_str(),
                    cnaName.c_str(),
                    e.what());
        }

        // Vulnerability scan ended for agent and package...
        logDebug1(WM_VULNSCAN_LOGTAG,
                  "Vulnerability scan for package '%s' on Agent '%s' has completed.",
                  packageName.c_str(),
                  data->agentId().data());

        // Nothing recorded: stop the chain here.
        if (data->m_elements.empty())
        {
            return nullptr;
        }
        return AbstractHandler<std::shared_ptr<TScanContext>>::handleRequest(std::move(data));
    }
};

/**
 * @brief Production package scanner: TPackageScanner bound to the concrete DatabaseFeedManager and ScanContext types
 * (the same types as the template's default arguments).
 */
using PackageScanner = TPackageScanner<DatabaseFeedManager, ScanContext>;

#endif // _PACKAGE_SCANNER_HPP
